{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "cqCt_GhvCnwY"
      },
      "source": [
        "# VQ-VAE training example\n",
        "\n",
        "Demonstration of how to train the model specified in https://arxiv.org/abs/1711.00937, using Haiku / JAX.\n",
        "\n",
        "On Mac and Linux, simply execute each cell in turn."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "O_8gPxwAq_3W"
      },
      "outputs": [],
      "source": [
        "# Uncomment the line below if running on colab.research.google.com\n",
        "# !pip install dm-haiku optax"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "95YuC82P35Of"
      },
      "outputs": [],
      "source": [
        "import haiku as hk\n",
        "import jax\n",
        "import optax\n",
        "import jax.numpy as jnp\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import tensorflow.compat.v2 as tf\n",
        "import tensorflow_datasets as tfds\n",
        "\n",
        "tf.enable_v2_behavior()\n",
        "\n",
        "print(\"JAX version {}\".format(jax.__version__))\n",
        "print(\"Haiku version {}\".format(hk.__version__))\n",
        "print(\"TF version {}\".format(tf.__version__))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "DT8fKmqQC35h"
      },
      "source": [
        "# Download CIFAR-10 data\n",
        "This requires a connection to the internet and will download ~160MB.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "pobKFxUGBD6-"
      },
      "outputs": [],
      "source": [
        "cifar10 = tfds.as_numpy(tfds.load(\"cifar10\", split=\"train+test\", batch_size=-1))\n",
        "del cifar10[\"id\"], cifar10[\"label\"]\n",
        "jax.tree.map(lambda x: f'{x.dtype.name}{list(x.shape)}', cifar10)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "lUgvEhfJyQLZ"
      },
      "source": [
        "# Load the data into Numpy\n",
        "We split the images into training, validation and test subsets, then compute the pixel variance of the training subset; it is used below to normalise the mean squared reconstruction error.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "9C-V2D6RSQwl"
      },
      "outputs": [],
      "source": [
        "train_data_dict = jax.tree.map(lambda x: x[:40000], cifar10)\n",
        "valid_data_dict = jax.tree.map(lambda x: x[40000:50000], cifar10)\n",
        "test_data_dict = jax.tree.map(lambda x: x[50000:], cifar10)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "cIRl2ZtxoKNz"
      },
      "outputs": [],
      "source": [
        "def cast_and_normalise_images(data_dict):\n",
        "  \"\"\"Returns a copy of `data_dict` whose images are rescaled to [-0.5, 0.5].\n",
        "\n",
        "  Args:\n",
        "    data_dict: mapping with an 'image' entry; pixel values are assumed to lie\n",
        "      in [0, 255].\n",
        "\n",
        "  Returns:\n",
        "    A new dict with the entries of `data_dict`, where 'image' has been cast to\n",
        "    float32, divided by 255 and shifted by -0.5. `data_dict` itself is not\n",
        "    modified, so the function can safely be called more than once on the\n",
        "    same input.\n",
        "  \"\"\"\n",
        "  images = tf.cast(data_dict['image'], tf.float32)\n",
        "  return {**data_dict, 'image': (images / 255.0) - 0.5}\n",
        "\n",
        "train_data_variance = np.var(train_data_dict['image'] / 255.0)\n",
        "print('train data variance: %s' % train_data_variance)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Jse__pEBAkvI"
      },
      "source": [
        "# Encoder \u0026 Decoder Architecture\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "1gwD36Vr6KqA"
      },
      "outputs": [],
      "source": [
        "class ResidualStack(hk.Module):\n",
        "  \"\"\"Stack of residual blocks followed by a final ReLU.\n",
        "\n",
        "  Each block computes `h + conv1x1(relu(conv3x3(relu(h))))`.\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,\n",
        "               name=None):\n",
        "    \"\"\"Creates the convolutions used by every residual block.\n",
        "\n",
        "    Args:\n",
        "      num_hiddens: output channels of each 1x1 convolution. Its output is\n",
        "        added to the block input, which is expected to have a matching\n",
        "        number of channels.\n",
        "      num_residual_layers: number of residual blocks in the stack.\n",
        "      num_residual_hiddens: output channels of each 3x3 convolution.\n",
        "      name: optional module name.\n",
        "    \"\"\"\n",
        "    super().__init__(name=name)\n",
        "    self._num_hiddens = num_hiddens\n",
        "    self._num_residual_layers = num_residual_layers\n",
        "    self._num_residual_hiddens = num_residual_hiddens\n",
        "\n",
        "    # One (3x3 conv, 1x1 conv) pair per residual block.\n",
        "    self._layers = []\n",
        "    for i in range(num_residual_layers):\n",
        "      conv3 = hk.Conv2D(\n",
        "          output_channels=num_residual_hiddens,\n",
        "          kernel_shape=(3, 3),\n",
        "          stride=(1, 1),\n",
        "          name=\"res3x3_%d\" % i)\n",
        "      conv1 = hk.Conv2D(\n",
        "          output_channels=num_hiddens,\n",
        "          kernel_shape=(1, 1),\n",
        "          stride=(1, 1),\n",
        "          name=\"res1x1_%d\" % i)\n",
        "      self._layers.append((conv3, conv1))\n",
        "\n",
        "  def __call__(self, inputs):\n",
        "    \"\"\"Applies every residual block to `inputs`, then a final ReLU.\"\"\"\n",
        "    h = inputs\n",
        "    for conv3, conv1 in self._layers:\n",
        "      residual = conv1(jax.nn.relu(conv3(jax.nn.relu(h))))\n",
        "      # Out-of-place add: `h += ...` would write into `inputs` if the caller\n",
        "      # passed a mutable (e.g. NumPy) array.\n",
        "      h = h + residual\n",
        "    return jax.nn.relu(h)  # Resnet V1 style\n",
        "\n",
        "\n",
        "class Encoder(hk.Module):\n",
        "  \"\"\"Convolutional encoder ending in a `ResidualStack`.\n",
        "\n",
        "  Two stride-2 4x4 convolutions are followed by a stride-1 3x3 convolution,\n",
        "  each with a ReLU, and then by the residual stack.\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,\n",
        "               name=None):\n",
        "    \"\"\"Builds the encoder layers.\n",
        "\n",
        "    Args:\n",
        "      num_hiddens: output channels of `enc_2`, `enc_3` and the residual stack;\n",
        "        `enc_1` uses half as many (integer division).\n",
        "      num_residual_layers: forwarded to `ResidualStack`.\n",
        "      num_residual_hiddens: forwarded to `ResidualStack`.\n",
        "      name: optional module name.\n",
        "    \"\"\"\n",
        "    super().__init__(name=name)\n",
        "    self._num_hiddens = num_hiddens\n",
        "    self._num_residual_layers = num_residual_layers\n",
        "    self._num_residual_hiddens = num_residual_hiddens\n",
        "\n",
        "    # Both 4x4 convolutions use stride 2; the 3x3 convolution uses stride 1.\n",
        "    self._enc_1 = hk.Conv2D(\n",
        "        output_channels=num_hiddens // 2,\n",
        "        kernel_shape=(4, 4), stride=(2, 2), name='enc_1')\n",
        "    self._enc_2 = hk.Conv2D(\n",
        "        output_channels=num_hiddens,\n",
        "        kernel_shape=(4, 4), stride=(2, 2), name='enc_2')\n",
        "    self._enc_3 = hk.Conv2D(\n",
        "        output_channels=num_hiddens,\n",
        "        kernel_shape=(3, 3), stride=(1, 1), name='enc_3')\n",
        "    self._residual_stack = ResidualStack(\n",
        "        num_hiddens, num_residual_layers, num_residual_hiddens)\n",
        "\n",
        "  def __call__(self, x):\n",
        "    \"\"\"Encodes `x` with the three convolutions and the residual stack.\"\"\"\n",
        "    h = x\n",
        "    for conv in (self._enc_1, self._enc_2, self._enc_3):\n",
        "      h = jax.nn.relu(conv(h))\n",
        "    return self._residual_stack(h)\n",
        "\n",
        "\n",
        "class Decoder(hk.Module):\n",
        "  \"\"\"Convolutional decoder producing a 3-channel output.\n",
        "\n",
        "  A stride-1 3x3 convolution and a `ResidualStack` are followed by two\n",
        "  stride-2 4x4 transposed convolutions. A ReLU follows the first transposed\n",
        "  convolution; the last one has no activation.\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,\n",
        "               name=None):\n",
        "    \"\"\"Builds the decoder layers.\n",
        "\n",
        "    Args:\n",
        "      num_hiddens: output channels of `dec_1` and the residual stack; `dec_2`\n",
        "        uses half as many (integer division).\n",
        "      num_residual_layers: forwarded to `ResidualStack`.\n",
        "      num_residual_hiddens: forwarded to `ResidualStack`.\n",
        "      name: optional module name.\n",
        "    \"\"\"\n",
        "    super().__init__(name=name)\n",
        "    self._num_hiddens = num_hiddens\n",
        "    self._num_residual_layers = num_residual_layers\n",
        "    self._num_residual_hiddens = num_residual_hiddens\n",
        "\n",
        "    self._dec_1 = hk.Conv2D(\n",
        "        output_channels=num_hiddens,\n",
        "        kernel_shape=(3, 3), stride=(1, 1), name='dec_1')\n",
        "    self._residual_stack = ResidualStack(\n",
        "        num_hiddens, num_residual_layers, num_residual_hiddens)\n",
        "    self._dec_2 = hk.Conv2DTranspose(\n",
        "        output_channels=num_hiddens // 2,\n",
        "        kernel_shape=(4, 4), stride=(2, 2), name='dec_2')\n",
        "    self._dec_3 = hk.Conv2DTranspose(\n",
        "        output_channels=3,\n",
        "        kernel_shape=(4, 4), stride=(2, 2), name='dec_3')\n",
        "\n",
        "  def __call__(self, x):\n",
        "    \"\"\"Decodes `x` and returns the output of `dec_3`.\"\"\"\n",
        "    features = self._residual_stack(self._dec_1(x))\n",
        "    upsampled = jax.nn.relu(self._dec_2(features))\n",
        "    return self._dec_3(upsampled)\n",
        "    \n",
        "\n",
        "class VQVAEModel(hk.Module):\n",
        "  \"\"\"Chains encoder, pre-quantisation projection, quantiser and decoder.\"\"\"\n",
        "\n",
        "  def __init__(self, encoder, decoder, vqvae, pre_vq_conv1,\n",
        "               data_variance, name=None):\n",
        "    \"\"\"Stores the sub-modules and the normalisation constant.\n",
        "\n",
        "    Args:\n",
        "      encoder: callable applied to the inputs.\n",
        "      decoder: callable applied to the quantised latents.\n",
        "      vqvae: quantiser called as `vqvae(z, is_training=...)`; its output must\n",
        "        provide the 'quantize' and 'loss' entries.\n",
        "      pre_vq_conv1: callable applied to the encoder output before quantising.\n",
        "      data_variance: divisor applied to the mean squared reconstruction error.\n",
        "      name: optional module name.\n",
        "    \"\"\"\n",
        "    super().__init__(name=name)\n",
        "    self._encoder = encoder\n",
        "    self._decoder = decoder\n",
        "    self._vqvae = vqvae\n",
        "    self._pre_vq_conv1 = pre_vq_conv1\n",
        "    self._data_variance = data_variance\n",
        "\n",
        "  def __call__(self, inputs, is_training):\n",
        "    \"\"\"Reconstructs `inputs` and returns the outputs together with the losses.\n",
        "\n",
        "    Returns:\n",
        "      A dict with 'z' (pre-quantisation latents), 'x_recon', 'loss'\n",
        "      (normalised reconstruction error plus the quantiser loss),\n",
        "      'recon_error' and the raw 'vq_output' of the quantiser.\n",
        "    \"\"\"\n",
        "    encoded = self._encoder(inputs)\n",
        "    z = self._pre_vq_conv1(encoded)\n",
        "    vq_output = self._vqvae(z, is_training=is_training)\n",
        "    x_recon = self._decoder(vq_output['quantize'])\n",
        "\n",
        "    mean_squared_error = jnp.mean((x_recon - inputs) ** 2)\n",
        "    recon_error = mean_squared_error / self._data_variance\n",
        "\n",
        "    outputs = {'z': z, 'x_recon': x_recon}\n",
        "    outputs['loss'] = recon_error + vq_output['loss']\n",
        "    outputs['recon_error'] = recon_error\n",
        "    outputs['vq_output'] = vq_output\n",
        "    return outputs"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "FF7WaOn-s7En"
      },
      "source": [
        "# Build Model and train"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "owGEoOkO4ttk"
      },
      "outputs": [],
      "source": [
        "# Set hyper-parameters.\n",
        "batch_size = 32\n",
        "image_size = 32\n",
        "\n",
        "# 100k steps should take \u003c 30 minutes on a modern (\u003e= 2017) GPU.\n",
        "num_training_updates = 100000\n",
        "\n",
        "num_hiddens = 128\n",
        "num_residual_hiddens = 32\n",
        "num_residual_layers = 2\n",
        "# These hyper-parameters define the size of the model (number of parameters and layers).\n",
        "# The hyper-parameters in the paper were (For ImageNet):\n",
        "# batch_size = 128\n",
        "# image_size = 128\n",
        "# num_hiddens = 128\n",
        "# num_residual_hiddens = 32\n",
        "# num_residual_layers = 2\n",
        "\n",
        "# This value is not that important, usually 64 works.\n",
        "# This will not change the capacity in the information-bottleneck.\n",
        "embedding_dim = 64\n",
        "\n",
        "# The higher this value, the higher the capacity in the information bottleneck.\n",
        "num_embeddings = 512\n",
        "\n",
        "# commitment_cost should be set appropriately. It's often useful to try a couple\n",
        "# of values. It mostly depends on the scale of the reconstruction cost\n",
        "# (log p(x|z)). So if the reconstruction cost is 100x higher, the\n",
        "# commitment_cost should also be multiplied by the same factor.\n",
        "commitment_cost = 0.25\n",
        "\n",
        "# Use EMA updates for the codebook (instead of the Adam optimizer).\n",
        "# This typically converges faster, and makes the model less dependent on choice\n",
        "# of the optimizer. EMA updates were not used in the VQ-VAE paper (they were\n",
        "# developed afterwards). See the Appendix of the paper for more details.\n",
        "vq_use_ema = True\n",
        "\n",
        "# This is only used for EMA updates.\n",
        "decay = 0.99\n",
        "\n",
        "learning_rate = 3e-4\n",
        "\n",
        "\n",
        "# # Data Loading.\n",
        "train_dataset = tfds.as_numpy(\n",
        "    tf.data.Dataset.from_tensor_slices(train_data_dict)\n",
        "    .map(cast_and_normalise_images)\n",
        "    .shuffle(10000)\n",
        "    .repeat(-1)  # repeat indefinitely\n",
        "    .batch(batch_size, drop_remainder=True)\n",
        "    .prefetch(-1))\n",
        "valid_dataset = tfds.as_numpy(\n",
        "    tf.data.Dataset.from_tensor_slices(valid_data_dict)\n",
        "    .map(cast_and_normalise_images)\n",
        "    .repeat(1)  # 1 epoch\n",
        "    .batch(batch_size)\n",
        "    .prefetch(-1))\n",
        "\n",
        "# # Build modules.\n",
        "def forward(data, is_training):\n",
        "  \"\"\"Builds the VQ-VAE from the hyper-parameters above and runs it.\n",
        "\n",
        "  Args:\n",
        "    data: mapping whose 'image' entry is fed to the model.\n",
        "    is_training: forwarded to the model (and from there to the quantiser).\n",
        "\n",
        "  Returns:\n",
        "    The output dict of `VQVAEModel`.\n",
        "  \"\"\"\n",
        "  encoder = Encoder(num_hiddens, num_residual_layers, num_residual_hiddens)\n",
        "  decoder = Decoder(num_hiddens, num_residual_layers, num_residual_hiddens)\n",
        "  pre_vq_conv1 = hk.Conv2D(\n",
        "      output_channels=embedding_dim,\n",
        "      kernel_shape=(1, 1), stride=(1, 1), name='to_vq')\n",
        "\n",
        "  # Arguments shared by both quantiser variants; only the EMA one takes `decay`.\n",
        "  vq_kwargs = dict(\n",
        "      embedding_dim=embedding_dim,\n",
        "      num_embeddings=num_embeddings,\n",
        "      commitment_cost=commitment_cost)\n",
        "  if vq_use_ema:\n",
        "    vq_vae = hk.nets.VectorQuantizerEMA(decay=decay, **vq_kwargs)\n",
        "  else:\n",
        "    vq_vae = hk.nets.VectorQuantizer(**vq_kwargs)\n",
        "\n",
        "  model = VQVAEModel(encoder, decoder, vq_vae, pre_vq_conv1,\n",
        "                     data_variance=train_data_variance)\n",
        "  images = data['image']\n",
        "  return model(images, is_training)\n",
        "\n",
        "forward = hk.transform_with_state(forward)\n",
        "optimizer = optax.adam(learning_rate)\n",
        "\n",
        "@jax.jit\n",
        "def train_step(params, state, opt_state, data):\n",
        "  \"\"\"Runs one optimiser update on a single batch.\n",
        "\n",
        "  Returns:\n",
        "    A tuple `(params, state, opt_state, model_output)` holding the updated\n",
        "    parameters, the model state returned by the forward pass, the updated\n",
        "    optimiser state and the model's output dict for this batch.\n",
        "  \"\"\"\n",
        "  def loss_fn(params, state, data):\n",
        "    # The loss is differentiated; outputs and new state travel as auxiliary data.\n",
        "    outputs, new_state = forward.apply(\n",
        "        params, state, None, data, is_training=True)\n",
        "    return outputs['loss'], (outputs, new_state)\n",
        "\n",
        "  grad_fn = jax.grad(loss_fn, has_aux=True)\n",
        "  grads, (model_output, state) = grad_fn(params, state, data)\n",
        "\n",
        "  updates, opt_state = optimizer.update(grads, opt_state)\n",
        "  new_params = optax.apply_updates(params, updates)\n",
        "  return new_params, state, opt_state, model_output"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "d7edmrBbJZy-"
      },
      "outputs": [],
      "source": [
        "%%time\n",
        "\n",
        "# Per-step training metrics, one entry per optimiser update.\n",
        "train_losses = []\n",
        "train_recon_errors = []\n",
        "train_perplexities = []\n",
        "train_vqvae_loss = []\n",
        "\n",
        "rng = jax.random.PRNGKey(42)\n",
        "train_dataset_iter = iter(train_dataset)\n",
        "# The first batch is only used to initialise parameters and state.\n",
        "params, state = forward.init(rng, next(train_dataset_iter), is_training=True)\n",
        "opt_state = optimizer.init(params)\n",
        "\n",
        "for step in range(1, num_training_updates + 1):\n",
        "  data = next(train_dataset_iter)\n",
        "  params, state, opt_state, train_results = (\n",
        "      train_step(params, state, opt_state, data))\n",
        "\n",
        "  # Only the four scalar metrics are needed on the host. Fetching the whole\n",
        "  # `train_results` tree would also copy the per-batch tensors it carries\n",
        "  # (e.g. 'z' and 'x_recon') from the device on every single step.\n",
        "  loss, recon_error, perplexity, vqvae_loss = jax.device_get((\n",
        "      train_results['loss'],\n",
        "      train_results['recon_error'],\n",
        "      train_results['vq_output']['perplexity'],\n",
        "      train_results['vq_output']['loss'],\n",
        "  ))\n",
        "  train_losses.append(loss)\n",
        "  train_recon_errors.append(recon_error)\n",
        "  train_perplexities.append(perplexity)\n",
        "  train_vqvae_loss.append(vqvae_loss)\n",
        "\n",
        "  # Report running means over the last 100 steps.\n",
        "  if step % 100 == 0:\n",
        "    print(f'[Step {step}/{num_training_updates}] ' +\n",
        "          ('train loss: %f ' % np.mean(train_losses[-100:])) +\n",
        "          ('recon_error: %.3f ' % np.mean(train_recon_errors[-100:])) +\n",
        "          ('perplexity: %.3f ' % np.mean(train_perplexities[-100:])) +\n",
        "          ('vqvae loss: %.3f' % np.mean(train_vqvae_loss[-100:])))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "m2hNyAnhs-1f"
      },
      "source": [
        "# Plot loss"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "2vo-lDyRomKD"
      },
      "outputs": [],
      "source": [
        "# Left: normalised reconstruction error (log scale). Right: perplexity.\n",
        "f = plt.figure(figsize=(16,8))\n",
        "ax = f.add_subplot(1,2,1)\n",
        "ax.plot(train_recon_errors)\n",
        "ax.set_yscale('log')\n",
        "ax.set_title('NMSE.')\n",
        "ax.set_xlabel('training step')\n",
        "ax.set_ylabel('normalised MSE')\n",
        "\n",
        "ax = f.add_subplot(1,2,2)\n",
        "ax.plot(train_perplexities)\n",
        "ax.set_title('Average codebook usage (perplexity).')\n",
        "ax.set_xlabel('training step')\n",
        "# The trailing semicolon keeps the Text repr out of the cell output.\n",
        "ax.set_ylabel('perplexity');"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Lyj1CCKptCZz"
      },
      "source": [
        "# View reconstructions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "rM9zj7ZiPZBG"
      },
      "outputs": [],
      "source": [
        "# Reconstructions\n",
        "train_batch = next(iter(train_dataset))\n",
        "valid_batch = next(iter(valid_dataset))\n",
        "\n",
        "# Run the model with is_training=False so that, when EMA is used, the\n",
        "# codebook is not updated. Only the model outputs (element 0) are kept.\n",
        "train_outputs = forward.apply(\n",
        "    params, state, rng, train_batch, is_training=False)[0]\n",
        "valid_outputs = forward.apply(\n",
        "    params, state, rng, valid_batch, is_training=False)[0]\n",
        "train_reconstructions = train_outputs['x_recon']\n",
        "valid_reconstructions = valid_outputs['x_recon']\n",
        "\n",
        "\n",
        "def convert_batch_to_image_grid(image_batch):\n",
        "  reshaped = (image_batch.reshape(4, 8, 32, 32, 3)\n",
        "              .transpose([0, 2, 1, 3, 4])\n",
        "              .reshape(4 * 32, 8 * 32, 3))\n",
        "  return reshaped + 0.5\n",
        "\n",
        "\n",
        "\n",
        "# One (title, image batch) pair per panel of the 2 x 2 figure, in reading order.\n",
        "panels = [\n",
        "    ('training data originals', train_batch['image']),\n",
        "    ('training data reconstructions', train_reconstructions),\n",
        "    ('validation data originals', valid_batch['image']),\n",
        "    ('validation data reconstructions', valid_reconstructions),\n",
        "]\n",
        "\n",
        "f = plt.figure(figsize=(16,8))\n",
        "for panel_index, (title, images) in enumerate(panels, start=1):\n",
        "  ax = f.add_subplot(2, 2, panel_index)\n",
        "  ax.imshow(convert_batch_to_image_grid(images),\n",
        "            interpolation='nearest')\n",
        "  ax.set_title(title)\n",
        "  # Hide ticks and frame on this panel explicitly rather than through the\n",
        "  # implicit current axes.\n",
        "  ax.axis('off')"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "",
        "kind": "private"
      },
      "name": "JAX VQ-VAE training example",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.4"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
